#!/usr/bin/env python3
# D12 — E/B Phase & Polarization — self-contained present-act engine (stdlib only)
# Control: pure boolean/ordinal predicates over integer shell/sector membership.
# Readouts: diagnostics only (phase lags, polarization); RNG unused (ties-only policy irrelevant here).

import argparse, csv, hashlib, json, math, os, sys
from datetime import datetime, timezone
from typing import Dict, List, Tuple

def utc_timestamp() -> str:
    """Return the current UTC time as a stamp such as ``2024-01-31T23-59-59Z``.

    Dashes are used in the time part as well, so the stamp contains no colons.
    """
    now_utc = datetime.now(tz=timezone.utc)
    return f"{now_utc:%Y-%m-%dT%H-%M-%SZ}"

def ensure_dirs(root: str, subdirs: List[str]) -> None:
    """Create each relative path in *subdirs* beneath *root*.

    Intermediate directories are created as needed; directories that already
    exist are left untouched.
    """
    targets = (os.path.join(root, rel) for rel in subdirs)
    for target in targets:
        os.makedirs(target, exist_ok=True)

def write_text(path: str, text: str) -> None:
    """Write *text* to *path* as UTF-8, replacing any existing content."""
    with open(path, mode="w", encoding="utf-8") as sink:
        sink.write(text)

def sha256_file(path: str) -> str:
    """Return the hex SHA-256 digest of the file at *path*.

    The file is read in 1 MiB chunks so it is never held in memory whole.
    """
    chunk_size = 1 << 20
    digest = hashlib.sha256()
    with open(path, "rb") as src:
        while chunk := src.read(chunk_size):
            digest.update(chunk)
    return digest.hexdigest()

def json_dump(path: str, obj: dict) -> None:
    """Serialize *obj* to *path* as UTF-8 JSON with sorted keys and 2-space indentation."""
    with open(path, mode="w", encoding="utf-8") as sink:
        json.dump(obj, sink, sort_keys=True, indent=2)

def modS(x: int, S: int) -> int:
    """Wrap the integer *x* into a sector index modulo *S*.

    Python's ``%`` on integers already yields a result with the sign of the
    divisor, so a single modulo gives 0..S-1 for positive S, including for
    negative *x*.
    """
    return x % S

def build_rs_counts(N: int, cx: int, cy: int, S: int) -> Dict[int, List[int]]:
    """
    Tally the cells of an N x N grid by (shell, sector) relative to (cx, cy).

    The shell of a cell is the integer square root of its squared distance from
    the centre. The sector is an integer 0..S-1 obtained by binning the atan2
    angle (wrapped into [0, 2*pi)) into S equal slices; the centre cell itself
    is assigned to sector 0.

    Returns a dict mapping each occurring shell r to a list of S counts, one
    per sector.
    """
    full_turn = 2.0*math.pi
    last_sector = S - 1
    table: Dict[int, List[int]] = {}
    for y in range(N):
        dy = y - cy
        for x in range(N):
            dx = x - cx
            shell = math.isqrt(dx*dx + dy*dy)
            sector = 0
            if shell != 0:
                theta = math.atan2(dy, dx)  # [-pi, pi]
                if theta < 0:
                    theta += full_turn
                # Clamp in case rounding lands exactly on the upper edge.
                sector = min(int((theta / full_turn) * S), last_sector)
            row = table.get(shell)
            if row is None:
                row = table[shell] = [0]*S
            row[sector] += 1
    return table

def circular_best_lag_norm(E: List[float], B: List[float], max_lag: int) -> Tuple[int, float]:
    """
    Circular normalized cross-correlation of E against a delayed copy of B.

    For each lag in 0..max_lag the score is

        sum_t E[t] * B[(t - lag) mod T]  /  sqrt(sum(E^2) * sum(B^2))

    i.e. B is shifted forward (delayed) by `lag` ticks with wrap-around and
    compared with E. Returns (best_lag, best_score); the earliest lag wins ties.
    The score lies in [-1, 1] up to rounding, and in [0, 1] when both series
    are non-negative.

    Special cases:
      * E is empty or len(B) != len(E)   -> (0, nan)
      * either series is all zeros       -> (0, 0.0)
      * max_lag < 0 (no lag evaluated)   -> (0, -1.0)
    """
    T = len(E)
    if T == 0 or len(B) != T:
        return 0, float("nan")
    sumE2 = sum(e*e for e in E)
    sumB2 = sum(b*b for b in B)
    if sumE2 == 0 or sumB2 == 0:
        return 0, 0.0
    # The normalization does not depend on the lag, so compute it once.
    norm = math.sqrt(sumE2 * sumB2)
    # Lags >= T repeat the scores of lag mod T and can never beat them under
    # the strict '>' comparison below, so scanning past T-1 is wasted work.
    last_lag = min(max_lag, T - 1)
    best_lag, best_val = 0, -1.0
    for lag in range(last_lag + 1):
        num = 0.0
        for t, e in enumerate(E):
            num += e * B[(t - lag) % T]
        val = num / norm
        if val > best_val:
            best_val, best_lag = val, lag
    return best_lag, best_val

def circular_mean_sector(values: List[int], S: int) -> float:
    """
    Circular mean of sector indices, expressed as a float sector index.

    Each index v is mapped to the angle 2*pi*v/S, the unit vectors are summed,
    and the direction of the resultant is converted back to sector units
    (0..S). An empty input yields nan.
    """
    if not values:
        return float("nan")
    full_turn = 2.0*math.pi
    phases = [full_turn * v / S for v in values]
    x_sum = sum(math.cos(p) for p in phases)
    y_sum = sum(math.sin(p) for p in phases)
    theta = math.atan2(y_sum, x_sum)
    # atan2 reports (-pi, pi]; move negative directions into [0, 2*pi).
    theta = theta + full_turn if theta < 0 else theta
    return (theta / full_turn) * S

def circ_dist_sectors(a: float, b: float, S: int) -> float:
    """Return the shortest way around a ring of S sectors between positions a and b, in sector units."""
    one_way = abs(a - b) % S
    other_way = S - one_way
    return min(one_way, other_way)

def main():
    """
    Run the D12 E/B phase & polarization diagnostic described by a JSON manifest.

    Command line: --manifest <path to manifest JSON> --outdir <output root>.

    Files written beneath --outdir:
      config/manifest_d12.json             re-serialized copy of the manifest
      logs/env.txt                         timestamp / platform snapshot
      outputs/metrics/d12_timeseries.csv   per-tick E and B for each detector sector
      outputs/audits/d12_audit.json        parameters, per-sector results and verdict
      outputs/run_info/result_line.txt     one-line summary (also printed to stdout)

    Raises:
      KeyError   if a required manifest key is missing.
      ValueError if sectors.S is not a positive integer.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--manifest", required=True)
    ap.add_argument("--outdir", required=True)
    args = ap.parse_args()

    ts = utc_timestamp()
    root = os.path.abspath(args.outdir)
    ensure_dirs(root, ["config", "outputs/metrics", "outputs/audits", "outputs/run_info", "logs"])

    with open(args.manifest, "r", encoding="utf-8") as f:
        manifest = json.load(f)
    manifest_path = os.path.join(root, "config", "manifest_d12.json")
    json_dump(manifest_path, manifest)

    # Env snapshot
    env_text = [f"utc={ts}", f"os={os.name}", f"cwd={os.getcwd()}", f"python={sys.version.split()[0]}"]
    write_text(os.path.join(root, "logs", "env.txt"), "\n".join(env_text))

    # Parameters
    N = int(manifest["grid"]["N"])
    cx = int(manifest["grid"].get("cx", N//2))
    cy = int(manifest["grid"].get("cy", N//2))
    H  = int(manifest["H"])
    S  = int(manifest["sectors"]["S"])
    step = int(manifest["sectors"].get("step_per_tick", 1))
    phi0 = int(manifest["source"]["phi0_sector"])  # starting orientation
    win_half = int(manifest["source"].get("ang_window_half", 0))  # 0 => exactly one sector wide

    rmin = int(manifest["source"]["radial_band"]["r_min"])
    rmax = int(manifest["source"]["radial_band"]["r_max"])

    det_sectors = [int(x) for x in manifest["detectors"]["sector_ids"]]
    lag_expected = int(manifest["acceptance"].get("lag_expected_ticks", 1))
    lag_spread_max = int(manifest["acceptance"].get("lag_spread_max_ticks", 0))
    corr_min = float(manifest["acceptance"].get("min_corrpeak_norm", 0.60))
    pol_tol = float(manifest["acceptance"].get("pol_tol_sectors", 1.0))

    # Sector arithmetic below is modulo S; fail early with a clear message
    # instead of an IndexError / ZeroDivisionError deeper in the run.
    if S <= 0:
        raise ValueError(f"sectors.S must be a positive integer, got {S}")

    # Precompute counts for (r,s)
    rs_counts = build_rs_counts(N, cx, cy, S)

    # Cells inside the radial band [rmin..rmax] per detector sector. This does
    # not depend on the tick, so it is summed once. Only sector ids in 0..S-1
    # can ever fall inside the orientation window, so others are skipped.
    band_count: Dict[int, float] = {
        j: float(sum(counts[j] for r, counts in rs_counts.items() if rmin <= r <= rmax))
        for j in det_sectors
        if 0 <= j < S
    }

    # Build E/B time series per detector sector
    T = H
    E_ts = {j: [0.0]*T for j in det_sectors}
    B_ts = {j: [0.0]*T for j in det_sectors}

    for t in range(T):
        k = modS(phi0 + step*t, S)              # current orientation sector
        prev = modS(k - step, S)                # previous sector (cw step)
        # sectors enabled in orientation window
        window_sectors = {modS(k + d, S) for d in range(-win_half, win_half+1)}

        for j in det_sectors:
            # E: band cell count while sector j is inside the orientation window.
            # This is present-act control via integer membership; counts are diagnostics.
            E_ts[j][t] = band_count[j] if j in window_sectors else 0.0
            # B: direction-aware exit boundary at sector j happens when prev == j (we step cw to k)
            B_ts[j][t] = 1.0 if prev == j else 0.0

    # Phase lags and correlations per detector sector
    sector_results = []
    lags = []
    corrs = []
    first_peaks_phi_candidates = []
    for j in det_sectors:
        lag, corr = circular_best_lag_norm(E_ts[j], B_ts[j], max_lag=S-1)
        lags.append(lag)
        corrs.append(corr)
        # earliest E peak time (first tick with positive E)
        t_first = next((ti for ti, val in enumerate(E_ts[j]) if val > 0.0), None)
        # The orientation at tick t is phi0 + step*t (mod S), so the first hit on
        # sector j gives phi0 = j - step*t_first (mod S). Omitting the step factor
        # is only correct when step_per_tick == 1.
        if t_first is not None:
            phi_cand = modS(j - step*t_first, S)
            first_peaks_phi_candidates.append(phi_cand)
        sector_results.append({
            "sector": j, "best_lag": lag, "norm_corr": corr, "first_peak_t": t_first
        })

    # Polarization estimate from circular mean of candidates
    phi_hat = circular_mean_sector(first_peaks_phi_candidates, S) if first_peaks_phi_candidates else float("nan")
    pol_err = circ_dist_sectors(phi_hat, float(phi0), S) if not math.isnan(phi_hat) else float("nan")

    # Acceptance checks (an empty detector list fails via ok_lag_spread)
    ok_lag_spread = (max(lags) - min(lags) <= lag_spread_max) if lags else False
    ok_lag_expected = all((lag == lag_expected) for lag in lags)
    ok_corr = all((c >= corr_min) for c in corrs)  # a nan correlation compares False
    ok_pol = (not math.isnan(pol_err)) and (pol_err <= pol_tol)

    passed = bool(ok_lag_spread and ok_lag_expected and ok_corr and ok_pol)

    # Write metrics: per-sector time series
    metrics_csv = os.path.join(root, "outputs", "metrics", "d12_timeseries.csv")
    with open(metrics_csv, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        header = ["tick"]
        for j in det_sectors:
            header += [f"E_s{j}", f"B_s{j}"]
        w.writerow(header)
        for t in range(T):
            row = [t]
            for j in det_sectors:
                row += [f"{E_ts[j][t]:.0f}", f"{B_ts[j][t]:.0f}"]
            w.writerow(row)

    # Audit JSON
    audit = {
        "sim": "D12_eb_phase_polarization",
        "S": S,
        "step_per_tick": step,
        "phi0_sector": phi0,
        "ang_window_half": win_half,
        "radial_band": {"r_min": rmin, "r_max": rmax},
        "detector_sectors": det_sectors,
        "sector_results": sector_results,
        "lags": lags,
        "corrs": corrs,
        "lag_expected": lag_expected,
        "lag_spread_max": lag_spread_max,
        "corr_min": corr_min,
        "phi_hat_sector": phi_hat,
        "pol_err_sectors": pol_err,
        "pass": passed,
        "manifest_hash": sha256_file(manifest_path)
    }
    json_dump(os.path.join(root, "outputs", "audits", "d12_audit.json"), audit)

    # Result line (nan values format as "nan", so no special-casing is needed)
    result_line = ("D12 PASS={p} lag*={lag}±{spread} corr_min={cm:.3f} phi_hat={ph:.2f} pol_err={pe:.2f}"
                   .format(p=passed, lag=(lags[0] if lags else None),
                           spread=(max(lags)-min(lags) if lags else None),
                           cm=(min(corrs) if corrs else float('nan')),
                           ph=phi_hat,
                           pe=pol_err))
    write_text(os.path.join(root, "outputs", "run_info", "result_line.txt"), result_line)
    print(result_line)

# Script entry point: run the engine only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
